{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jupyter/.local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "from unidecode import unidecode\n",
    "\n",
    "# sklearn.cross_validation was removed in scikit-learn 0.20; prefer its replacement\n",
    "# and only fall back to the legacy module on very old installations.\n",
    "try:\n",
    "    from sklearn.model_selection import train_test_split\n",
    "except ImportError:\n",
    "    from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    '''Build the vocabulary lookup tables and encode `words` as integer ids.\n",
    "\n",
    "    Ids 0-4 are reserved, in this order, for PAD, GO, EOS, UNK and SEPARATOR.\n",
    "    The `n_words - 1` most frequent tokens of `words` receive the ids that follow.\n",
    "\n",
    "    Args:\n",
    "        words: sequence of string tokens (the flattened corpus).\n",
    "        n_words: vocabulary budget; `n_words - 1` is handed to `Counter.most_common`.\n",
    "\n",
    "    Returns:\n",
    "        data: list with one id per entry of `words`; tokens outside the vocabulary\n",
    "            are encoded with the UNK id.\n",
    "        count: the reserved `[token, id]` pairs followed by `(token, frequency)`\n",
    "            tuples. The second field of the UNK pair is overwritten with the number\n",
    "            of out-of-vocabulary tokens found in `words`.\n",
    "        dictionary: mapping from token to id.\n",
    "        reversed_dictionary: mapping from id back to token.\n",
    "    '''\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3], ['SEPARATOR', 4]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words - 1))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        # A reserved token may also occur in the corpus (e.g. a literal 'SEPARATOR').\n",
    "        # Assigning it again would hand it the id of the next word, so two tokens\n",
    "        # would share one id and the reserved id would be orphaned.\n",
    "        if word not in dictionary:\n",
    "            dictionary[word] = len(dictionary)\n",
    "    unk_index = dictionary['UNK']\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word)\n",
    "        if index is None:\n",
    "            # Out-of-vocabulary tokens belong to UNK, not to the padding id 0.\n",
    "            index = unk_index\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[unk_index][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary\n",
    "\n",
    "def str_idx(corpus, dic, maxlen, UNK=3):\n",
    "    '''Encode each sentence as a fixed-length, left-padded row of token ids.\n",
    "\n",
    "    Args:\n",
    "        corpus: sequence of sentences. A sentence is either a whitespace-separated\n",
    "            string or an already tokenised sequence of words.\n",
    "        dic: mapping from token to integer id.\n",
    "        maxlen: number of columns; only the first `maxlen` tokens of a sentence are kept.\n",
    "        UNK: id used for tokens missing from `dic`.\n",
    "\n",
    "    Returns:\n",
    "        int32 array of shape (len(corpus), maxlen). Tokens are right-aligned and the\n",
    "        unused leading positions stay 0.\n",
    "    '''\n",
    "    X = np.zeros((len(corpus), maxlen), dtype=np.int32)\n",
    "    for i, sentence in enumerate(corpus):\n",
    "        # Slicing a raw string would walk over characters, so split it into words first.\n",
    "        tokens = sentence.split() if isinstance(sentence, str) else sentence\n",
    "        for no, k in enumerate(tokens[:maxlen][::-1]):\n",
    "            X[i, -1 - no] = dic.get(k, UNK)\n",
    "    return X\n",
    "\n",
    "def cleaning(string):\n",
    "    '''Normalise text to lowercase ASCII letters, hyphens and single spaces.\n",
    "\n",
    "    The input is transliterated with `unidecode`, every run of characters other than\n",
    "    A-Z, a-z, '-' and space becomes one space, repeated spaces are collapsed, the\n",
    "    ends are stripped and the result is lowercased.\n",
    "\n",
    "    Args:\n",
    "        string: the raw text.\n",
    "\n",
    "    Returns:\n",
    "        The cleaned, lowercased string (possibly empty).\n",
    "    '''\n",
    "    # Raw string: '\\-' in a normal literal is an invalid escape sequence.\n",
    "    # Punctuation such as '.' and ',' needs no special casing, the pattern removes it.\n",
    "    string = re.sub(r'[^A-Za-z\\- ]+', ' ', unidecode(string))\n",
    "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
    "    return string.lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>qid1</th>\n",
       "      <th>qid2</th>\n",
       "      <th>question1</th>\n",
       "      <th>question2</th>\n",
       "      <th>is_duplicate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>What is the step by step guide to invest in sh...</td>\n",
       "      <td>What is the step by step guide to invest in sh...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>What is the story of Kohinoor (Koh-i-Noor) Dia...</td>\n",
       "      <td>What would happen if the Indian government sto...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "      <td>How can I increase the speed of my internet co...</td>\n",
       "      <td>How can Internet speed be increased by hacking...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>8</td>\n",
       "      <td>Why am I mentally very lonely? How can I solve...</td>\n",
       "      <td>Find the remainder when [math]23^{24}[/math] i...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>9</td>\n",
       "      <td>10</td>\n",
       "      <td>Which one dissolve in water quikly sugar, salt...</td>\n",
       "      <td>Which fish would survive in salt water?</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  qid1  qid2                                          question1  \\\n",
       "0   0     1     2  What is the step by step guide to invest in sh...   \n",
       "1   1     3     4  What is the story of Kohinoor (Koh-i-Noor) Dia...   \n",
       "2   2     5     6  How can I increase the speed of my internet co...   \n",
       "3   3     7     8  Why am I mentally very lonely? How can I solve...   \n",
       "4   4     9    10  Which one dissolve in water quikly sugar, salt...   \n",
       "\n",
       "                                           question2  is_duplicate  \n",
       "0  What is the step by step guide to invest in sh...             0  \n",
       "1  What would happen if the Indian government sto...             0  \n",
       "2  How can Internet speed be increased by hacking...             0  \n",
       "3  Find the remainder when [math]23^{24}[/math] i...             0  \n",
       "4            Which fish would survive in salt water?             0  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the tab-separated Quora question pairs and drop rows with a missing field.\n",
    "df = pd.read_csv('quora_duplicate_questions.tsv', sep='\\t').dropna()\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the two question columns and the duplicate flag out as plain Python lists.\n",
    "left = df['question1'].tolist()\n",
    "right = df['question2'].tolist()\n",
    "label = df['is_duplicate'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0, 1]), array([255024, 149263]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class balance: the distinct labels and how many pairs carry each of them.\n",
    "np.unique(np.asarray(label), return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 404287/404287 [00:07<00:00, 51783.93it/s]\n"
     ]
    }
   ],
   "source": [
    "# Always start from the raw dataframe columns so that re-running this cell is safe:\n",
    "# cleaning the already-joined strings again would lowercase the SEPARATOR marker and\n",
    "# append a second copy of the right-hand question.\n",
    "raw_left, raw_right = df['question1'].tolist(), df['question2'].tolist()\n",
    "for i in tqdm(range(len(raw_left))):\n",
    "    right[i] = cleaning(raw_right[i])\n",
    "    left[i] = cleaning(raw_left[i]) + ' SEPARATOR ' + right[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 87662\n",
      "Most common words [['SEPARATOR', 4], ('SEPARATOR', 404287), ('the', 377593), ('what', 324635), ('is', 269934), ('i', 223893)]\n",
      "Sample data [6, 7, 5, 1286, 63, 1286, 2502, 11, 565, 12] ['what', 'is', 'the', 'step', 'by', 'step', 'guide', 'to', 'invest', 'in']\n"
     ]
    }
   ],
   "source": [
    "# Flatten every joined question pair into one token stream and index the vocabulary.\n",
    "concat = ' '.join(left).split()\n",
    "vocabulary_size = len(set(concat))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d' % vocabulary_size)\n",
    "print('Most common words', count[4:10])\n",
    "sample_ids = data[:10]\n",
    "print('Sample data', sample_ids, [rev_dictionary[idx] for idx in sample_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def position_encoding(inputs):\n",
    "    '''Build sinusoidal position encodings tiled to the batch size of `inputs`.\n",
    "\n",
    "    `inputs` is indexed as (batch, time, depth); the depth must be statically known.\n",
    "    The sine half and the cosine half are concatenated along the last axis, each with\n",
    "    one column per even index below depth (so depth is presumably meant to be even).\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor with the batch and time dimensions of `inputs`.\n",
    "    '''\n",
    "    n_steps = tf.shape(inputs)[1]\n",
    "    depth = inputs.get_shape()[-1].value\n",
    "    positions = tf.reshape(tf.range(0.0, tf.to_float(n_steps), dtype=tf.float32), [-1, 1])\n",
    "    exponents = np.arange(0, depth, 2, np.float32) / depth\n",
    "    wavelengths = np.reshape(np.power(10000.0, exponents), [1, -1])\n",
    "    angles = positions / wavelengths\n",
    "    table = tf.concat([tf.sin(angles), tf.cos(angles)], 1)\n",
    "    # Add a leading batch axis, then repeat the table once per example.\n",
    "    return tf.tile(tf.expand_dims(table, 0), [tf.shape(inputs)[0], 1, 1])\n",
    "\n",
    "def layer_norm(inputs, epsilon=1e-8):\n",
    "    '''Normalise `inputs` over the last axis, then apply a learned scale and shift.\n",
    "\n",
    "    The variables 'gamma' (initialised to ones) and 'beta' (initialised to zeros) are\n",
    "    created with `tf.get_variable` in the current variable scope. `epsilon` is added\n",
    "    to the variance before the square root.\n",
    "    '''\n",
    "    depth_shape = inputs.get_shape()[-1:]\n",
    "    mu, sigma_sq = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "    scale = tf.get_variable('gamma', depth_shape, tf.float32, tf.ones_initializer())\n",
    "    shift = tf.get_variable('beta', depth_shape, tf.float32, tf.zeros_initializer())\n",
    "    standardized = (inputs - mu) / tf.sqrt(sigma_sq + epsilon)\n",
    "    return scale * standardized + shift\n",
    "\n",
    "def cnn_block(x, dilation_rate, pad_sz, hidden_dim, kernel_size):\n",
    "    '''Layer-norm, zero-pad, dilated 1-D convolution and ReLU.\n",
    "\n",
    "    `pad_sz` zero steps of width `hidden_dim` are concatenated before and after the\n",
    "    normalised input along the time axis, and the last `pad_sz` steps of the\n",
    "    convolution output are dropped again. With\n",
    "    pad_sz == (kernel_size - 1) * dilation_rate the result has as many time steps as\n",
    "    `x`. The last dimension of `x` has to equal `hidden_dim` for the concatenation.\n",
    "    '''\n",
    "    normed = layer_norm(x)\n",
    "    zeros = tf.zeros([tf.shape(normed)[0], pad_sz, hidden_dim])\n",
    "    padded = tf.concat([zeros, normed, zeros], 1)\n",
    "    conv = tf.layers.conv1d(inputs = padded,\n",
    "                            filters = hidden_dim,\n",
    "                            kernel_size = kernel_size,\n",
    "                            dilation_rate = dilation_rate)\n",
    "    return tf.nn.relu(conv[:, :-pad_sz, :])\n",
    "\n",
    "class Model:\n",
    "    '''Dilated-CNN sequence classifier built as a TensorFlow 1.x graph.\n",
    "\n",
    "    Attributes:\n",
    "        X: int32 placeholder of token ids, shape [batch, time].\n",
    "        Y: int32 placeholder of class labels, shape [batch].\n",
    "        logits: class scores taken from the last time step, shape [batch, num_classes].\n",
    "        cost: mean sparse softmax cross-entropy between `logits` and `Y`.\n",
    "        optimizer: Adam training op minimising `cost`.\n",
    "        accuracy: fraction of examples whose arg-max logit equals `Y`.\n",
    "\n",
    "    Args:\n",
    "        size_layer: number of convolution filters in every block.\n",
    "        num_layers: number of residual blocks; block n uses dilation 2 ** n.\n",
    "        embedded_size: embedding width. It has to equal `size_layer`, because the\n",
    "            blocks add their output to their input.\n",
    "        dict_size: number of rows in the embedding table.\n",
    "        learning_rate: Adam learning rate.\n",
    "        dropout: accepted for interface compatibility; not applied anywhere in the graph.\n",
    "        kernel_size: convolution kernel width.\n",
    "        num_classes: number of output classes (2 for duplicate / not duplicate).\n",
    "    '''\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 dict_size, learning_rate, dropout, kernel_size = 5, num_classes = 2):\n",
    "\n",
    "        def cnn(x, scope):\n",
    "            x = x + position_encoding(x)\n",
    "            with tf.variable_scope(scope, reuse = tf.AUTO_REUSE):\n",
    "                for n in range(num_layers):\n",
    "                    dilation_rate = 2 ** n\n",
    "                    pad_sz = (kernel_size - 1) * dilation_rate\n",
    "                    # Name the scope after this block's own index. The loop variable is\n",
    "                    # `n`; any other name would be looked up outside the function and\n",
    "                    # make all blocks share one scope, and with AUTO_REUSE one set of\n",
    "                    # weights.\n",
    "                    with tf.variable_scope('block_%d' % n, reuse = tf.AUTO_REUSE):\n",
    "                        x += cnn_block(x, dilation_rate, pad_sz, size_layer, kernel_size)\n",
    "\n",
    "                with tf.variable_scope('logits', reuse = tf.AUTO_REUSE):\n",
    "                    # Only the last time step is classified, so project just that step\n",
    "                    # onto one score per class.\n",
    "                    return tf.layers.dense(x[:, -1], num_classes)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        embedded_left = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "\n",
    "        self.logits = cnn(embedded_left, 'left')\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters. The embedding width and the convolution width share one value.\n",
    "size_layer = embedded_size = 128\n",
    "num_layers = 4\n",
    "learning_rate = 0.001\n",
    "# Sequence length fed to the network and mini-batch size.\n",
    "maxlen, batch_size = 50, 128\n",
    "dropout = 0.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every joined question pair as a fixed-length id row, then hold out 20% of\n",
    "# the pairs for validation. `train_test_split` comes from the import cell at the top;\n",
    "# the fixed seed keeps the split identical between runs.\n",
    "vectors = str_idx(left, dictionary, maxlen)\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(\n",
    "    vectors, label, test_size = 0.2, random_state = 42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py:1702: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    }
   ],
   "source": [
    "# Close the session left over from an earlier run of this cell before resetting the\n",
    "# graph. Otherwise the old InteractiveSession stays active and keeps its resources.\n",
    "try:\n",
    "    sess.close()\n",
    "except NameError:\n",
    "    pass  # first run on a fresh kernel: there is no session yet\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(size_layer, num_layers, embedded_size, len(dictionary), learning_rate, dropout)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:32<00:00, 77.42it/s, accuracy=0.584, cost=0.645]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 241.53it/s, accuracy=0.678, cost=0.624]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.72it/s, accuracy=0.664, cost=0.638]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.649024\n",
      "time taken: 35.25988554954529\n",
      "epoch: 0, training loss: 0.639172, training acc: 0.645532, valid loss: 0.625583, valid acc: 0.649024\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.78it/s, accuracy=0.653, cost=0.605]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 249.91it/s, accuracy=0.756, cost=0.568]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.57it/s, accuracy=0.648, cost=0.625]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.649024, current acc: 0.686088\n",
      "time taken: 33.05776762962341\n",
      "epoch: 0, training loss: 0.599088, training acc: 0.681601, valid loss: 0.593228, valid acc: 0.686088\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.65it/s, accuracy=0.703, cost=0.568]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.40it/s, accuracy=0.7, cost=0.548]  \n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 83.52it/s, accuracy=0.672, cost=0.615]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.686088, current acc: 0.700928\n",
      "time taken: 33.1018283367157\n",
      "epoch: 0, training loss: 0.572584, training acc: 0.705614, valid loss: 0.578908, valid acc: 0.700928\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.66it/s, accuracy=0.723, cost=0.556]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 251.01it/s, accuracy=0.778, cost=0.521]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 83.45it/s, accuracy=0.703, cost=0.604]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.700928, current acc: 0.705392\n",
      "time taken: 33.0923171043396\n",
      "epoch: 0, training loss: 0.550349, training acc: 0.723289, valid loss: 0.573883, valid acc: 0.705392\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.63it/s, accuracy=0.733, cost=0.545]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 249.63it/s, accuracy=0.767, cost=0.526]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 83.85it/s, accuracy=0.727, cost=0.582]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.705392, current acc: 0.706215\n",
      "time taken: 33.11521649360657\n",
      "epoch: 0, training loss: 0.530263, training acc: 0.737710, valid loss: 0.574223, valid acc: 0.706215\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.78it/s, accuracy=0.703, cost=0.507]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 249.87it/s, accuracy=0.722, cost=0.548]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.88it/s, accuracy=0.68, cost=0.566] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.706215, current acc: 0.712823\n",
      "time taken: 33.06076192855835\n",
      "epoch: 0, training loss: 0.512012, training acc: 0.749806, valid loss: 0.572262, valid acc: 0.712823\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.65it/s, accuracy=0.713, cost=0.539]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 249.05it/s, accuracy=0.711, cost=0.54] \n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.51it/s, accuracy=0.672, cost=0.576]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.712823, current acc: 0.715365\n",
      "time taken: 33.11697006225586\n",
      "epoch: 0, training loss: 0.495308, training acc: 0.760959, valid loss: 0.575378, valid acc: 0.715365\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.82it/s, accuracy=0.713, cost=0.55] \n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.17it/s, accuracy=0.689, cost=0.578]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.39it/s, accuracy=0.719, cost=0.558]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.715365, current acc: 0.718076\n",
      "time taken: 33.03995633125305\n",
      "epoch: 0, training loss: 0.480132, training acc: 0.770668, valid loss: 0.576161, valid acc: 0.718076\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.96it/s, accuracy=0.723, cost=0.532]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.37it/s, accuracy=0.689, cost=0.571]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 83.53it/s, accuracy=0.734, cost=0.556]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.718076, current acc: 0.718397\n",
      "time taken: 32.98867201805115\n",
      "epoch: 0, training loss: 0.466953, training acc: 0.778197, valid loss: 0.585377, valid acc: 0.718397\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.74it/s, accuracy=0.693, cost=0.579]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.47it/s, accuracy=0.722, cost=0.579]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 83.07it/s, accuracy=0.727, cost=0.532]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.718397, current acc: 0.719860\n",
      "time taken: 33.069570541381836\n",
      "epoch: 0, training loss: 0.454996, training acc: 0.786085, valid loss: 0.589913, valid acc: 0.719860\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.83it/s, accuracy=0.703, cost=0.545]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.12it/s, accuracy=0.744, cost=0.56] \n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.83it/s, accuracy=0.711, cost=0.518]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.719860, current acc: 0.722752\n",
      "time taken: 33.03630042076111\n",
      "epoch: 0, training loss: 0.443845, training acc: 0.792981, valid loss: 0.597150, valid acc: 0.722752\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.84it/s, accuracy=0.743, cost=0.536]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 249.00it/s, accuracy=0.744, cost=0.57] \n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.01it/s, accuracy=0.75, cost=0.504] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 33.04733848571777\n",
      "epoch: 0, training loss: 0.433595, training acc: 0.798370, valid loss: 0.605825, valid acc: 0.720378\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.71it/s, accuracy=0.762, cost=0.505]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 250.76it/s, accuracy=0.756, cost=0.567]\n",
      "train minibatch loop:   0%|          | 9/2527 [00:00<00:30, 82.74it/s, accuracy=0.75, cost=0.51]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 33.075902462005615\n",
      "epoch: 0, training loss: 0.423926, training acc: 0.803343, valid loss: 0.617053, valid acc: 0.721669\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2527/2527 [00:30<00:00, 82.80it/s, accuracy=0.723, cost=0.501]\n",
      "test minibatch loop: 100%|██████████| 632/632 [00:02<00:00, 251.00it/s, accuracy=0.778, cost=0.559]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 33.04087018966675\n",
      "epoch: 0, training loss: 0.415806, training acc: 0.808235, valid loss: 0.627675, valid acc: 0.719070\n",
      "\n",
      "break epoch:0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Stop once validation accuracy has failed to improve EARLY_STOPPING epochs in a row.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "\n",
    "def run_epoch(inputs, targets, desc, train):\n",
    "    '''Run one pass over (inputs, targets) in mini-batches.\n",
    "\n",
    "    Args:\n",
    "        inputs: array of id rows fed to `model.X`.\n",
    "        targets: labels fed to `model.Y`, aligned with `inputs`.\n",
    "        desc: label shown on the progress bar.\n",
    "        train: when true the optimizer op is run as well and a NaN loss aborts the run.\n",
    "\n",
    "    Returns:\n",
    "        (mean loss, mean accuracy) over all examples. Batch metrics are weighted by\n",
    "        batch size, so a shorter final batch does not skew the averages.\n",
    "    '''\n",
    "    fetches = [model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "    total_loss, total_acc = 0.0, 0.0\n",
    "    pbar = tqdm(range(0, len(inputs), batch_size), desc = desc)\n",
    "    for start in pbar:\n",
    "        batch_x = inputs[start:start + batch_size]\n",
    "        batch_y = targets[start:start + batch_size]\n",
    "        acc, loss = sess.run(fetches,\n",
    "                             feed_dict = {model.X : batch_x,\n",
    "                                          model.Y : batch_y})[:2]\n",
    "        if train:\n",
    "            assert not np.isnan(loss)\n",
    "        total_loss += loss * len(batch_x)\n",
    "        total_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost = loss, accuracy = acc)\n",
    "    return total_loss / len(inputs), total_acc / len(inputs)\n",
    "\n",
    "\n",
    "while CURRENT_CHECKPOINT != EARLY_STOPPING:\n",
    "    lasttime = time.time()\n",
    "    train_loss, train_acc = run_epoch(train_X, train_Y, 'train minibatch loop', train = True)\n",
    "    test_loss, test_acc = run_epoch(test_X, test_Y, 'test minibatch loop', train = False)\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print('epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "          % (EPOCH, train_loss, train_acc, test_loss, test_acc))\n",
    "    # Advance the counter so each log line reports its own epoch.\n",
    "    EPOCH += 1\n",
    "\n",
    "print('break epoch:%d\\n' % (EPOCH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
